﻿#pragma kernel SortFluidData

#pragma kernel Emit
#pragma kernel EmitShape
#pragma kernel CopyAliveCount
#pragma kernel Update
#pragma kernel Integrate
#pragma kernel Copy

#pragma kernel Sort
#pragma kernel ClearMesh
#pragma kernel BuildMesh

#include "InterlockedUtils.cginc"
#include "MathUtils.cginc"
#include "GridUtils.cginc"
#include "Simplex.cginc"
#include "Bounds.cginc"
#include "SolverParameters.cginc"
#include "FluidKernels.cginc"

// Fluid particle data gathered by SortFluidData: entry i holds the data of particle sortedToOriginal[i].
StructuredBuffer<int> sortedToOriginal;
RWStructuredBuffer<float4> sortedPositions;
RWStructuredBuffer<float4> sortedVelocities;
RWStructuredBuffer<float4> sortedAngularVelocities;
RWStructuredBuffer<quaternion> sortedOrientations;
RWStructuredBuffer<float4> sortedRadii;

// Grid data used by Update to find the fluid particles around each foam particle.
StructuredBuffer<uint> cellOffsets;    // start of each cell in the sorted item array.
StructuredBuffer<uint> cellCounts;     // number of item in each cell.
StructuredBuffer<int> gridHashToSortedIndex; // maps GridHash(cell coordinates) to an index into cellOffsets / cellCounts.
StructuredBuffer<aabb> solverBounds;   // only element 0 is read; its min_ is the origin of the grid.

// Fluid particle data, indexed by particle index.
StructuredBuffer<uint> fluidSimplices; // indexed by sorted cell entry, yields a particle index (see Update).
StructuredBuffer<int> activeParticles; // particle indices processed by Emit.
StructuredBuffer<float4> positions;
StructuredBuffer<float4> orientations;
StructuredBuffer<float4> velocities;
RWStructuredBuffer<float4> angularVelocities; // w stores the foam potential packed with PackFloatRG (read and written by Emit).
StructuredBuffer<float4> principalRadii;
StructuredBuffer<float4> fluidMaterial;
StructuredBuffer<float4> fluidData;   // x divides the kernel weight in Update (normalized density), z is remapped with vorticityRange in Emit.

// Foam particle data read by Update, Copy and BuildMesh.
StructuredBuffer<float4> inputPositions; // w component is distance traveled inside volume (approximate volumetric lighting).
StructuredBuffer<float4> inputVelocities; // w component is buoyancy
StructuredBuffer<float4> inputColors;     // rgba diffuse color
StructuredBuffer<float4> inputAttributes; // x: remaining life (emitted as 1, particle is dropped once <= 0), y: aging rate (1 / initial lifetime), z: size, w: PackFloatRGBA(airAging / 50, airDrag, drag, isosurface)

// Foam particle data written by the emission, update, integration, copy and sort kernels (same layout as the input buffers).
RWStructuredBuffer<float4> outputPositions;
RWStructuredBuffer<float4> outputVelocities;
RWStructuredBuffer<float4> outputColors;
RWStructuredBuffer<float4> outputAttributes;

// Counters shared between kernels. Slots used in this file:
//   [0] thread group count derived from the particle count (CopyAliveCount, Copy).
//   [3] alive particle counter: atomically incremented by Emit / EmitShape, decremented by Update, restored by Copy.
//   [4] thread group count for the particles that survived Update.
//   [7] number of particles that survived Update.
//   [8] snapshot of [3] taken by CopyAliveCount.
// NOTE(review): [0] and [4] look like indirect dispatch arguments - confirm against the host code.
RWStructuredBuffer<uint> dispatch;
RWByteAddressBuffer vertices; // written by BuildMesh: 19 32-bit words per vertex, 4 vertices per particle.
RWByteAddressBuffer indices;  // written by ClearMesh and BuildMesh: 6 indices per particle.

// Variables set from the CPU
uint activeParticleCount;   // number of entries in activeParticles processed by Emit.
uint maxFoamParticles;      // capacity of the foam buffers: slots at or beyond this index are discarded.
uint particlesToEmit;       // number of particles EmitShape tries to spawn.

uint emitterShape;          // 0: cylinder, any other value: box (see EmitShape).
float4 emitterPosition;     // added to the generated point after rotating it.
quaternion emitterRotation; // rotates the point generated in emitter-local space.
float4 emitterSize;         // box: extents along xyz. cylinder: y is the length, max(x, z) the diameter.

uint minFluidNeighbors;     // minimum amount of fluid neighbors for a particle to be treated as foam instead of spray.
float2 vorticityRange;      // x/y: range passed to Remap01 together with fluidData.z.
float2 velocityRange;       // x/y: range passed to Remap01 together with the fluid speed.
float foamGenerationRate;   // scales how fast the emission accumulator grows with the foam potential.
float potentialIncrease;    // scales the potential gained per second.
float potentialDiffusion;   // multiplies the previous potential every time Emit runs.

float advectionRadius;      // not referenced by the kernels in this file.
float lifetime;             // base lifetime of emitted particles.
float lifetimeRandom;       // fraction of lifetime that is randomly subtracted.
float particleSize;         // base size of emitted particles.
float buoyancy;             // stored in the w component of the particle velocity.
float drag;                 // packed into attributes.w, scales the advection force in Update.
float airDrag;              // packed into attributes.w, velocity multiplier applied to spray particles in Update.
float sizeRandom;           // fraction of particleSize that is randomly subtracted.
float isosurface;           // packed into attributes.w, subtracted from the kernel sum before the foam/spray test.
float airAging;             // packed into attributes.w (divided by 50), aging multiplier for spray particles.
float3 agingOverPopulation; // x/y: range of buffer occupancy passed to Remap01, z: aging multiplier reached at the top of the range.
float4 foamColor;           // color assigned to emitted particles.
float4 sortAxis;            // xyz: axis particle positions are projected onto by Sort.

// Parameters of the current Sort pass.
const uint groupWidth;
const uint groupHeight;
const uint stepIndex;

float deltaTime;            // time step, in the same time units used by velocities.
float randomSeed;           // offset added to the per-particle hash seeds.

// Cell offsets (0 or 1 per axis) used by Update to visit the cells around a foam particle.
// Each entry is multiplied component-wise by the particle's quadrant vector (per-axis sign in xyz,
// grid level in w), so w = 1 carries the grid level over to the neighbor cell coordinates.
// The first 4 entries have z = 0: they are the only ones visited when mode == 1.
static const int4 quadrantOffsets[] =
{
    int4(0, 0, 0, 1),
    int4(1, 0, 0, 1),
    int4(0, 1, 0, 1),
    int4(1, 1, 0, 1),
    int4(0, 0, 1, 1),
    int4(1, 0, 1, 1),
    int4(0, 1, 1, 1),
    int4(1, 1, 1, 1)
};

// Generates a pseudo-random point inside a cylinder.
//   seed:     value passed to hash31 to obtain three pseudo-random numbers.
//   pos:      start point of the cylinder axis (only xyz is read).
//   dir:      direction of the cylinder axis (only xyz is read). Callers pass a normalized vector.
//   len:      length of the cylinder along dir.
//   radius:   radius of the cylinder.
//   position: (out) generated point, w is set to 0.
//   velocity: (out) offset of the generated point from the cylinder axis.
void RandomInCylinder(float seed, float4 pos, float4 dir, float len, float radius, out float4 position, out float3 velocity)
{
    float3 rand = hash31(seed);

    // build a basis perpendicular to the axis:
    float3 b1 = dir.xyz;
    float3 b2 = cross(b1, float3(1,0,0));

    // when the axis is (nearly) parallel to X the cross product vanishes and the whole disc
    // would collapse onto the axis, so use Y as the helper axis in that case.
    if (dot(b2, b2) < 1e-6)
        b2 = cross(b1, float3(0,1,0));

    b2 = normalizesafe(b2);
    float3 b3 = cross(b2, b1);
    
    // random point on a disc: assuming hash31 returns uniform values in the 0-1 range,
    // the square root keeps the point density uniform over the disc area.
    float theta = rand.y * 2 * PI;
    float2 disc = radius * sqrt(rand.x) * float2(cos(theta),sin(theta));

    velocity = b2 * disc.x + b3 * disc.y;

    // random distance along the axis, plus the radial offset:
    position = float4(pos.xyz + b1 * len * rand.z + velocity,0); 
}

// Generates a pseudo-random point inside a box.
//   seed:     value passed to hash31 to obtain three pseudo-random numbers.
//   center:   center of the box (only xyz is read).
//   size:     extents of the box along each axis (only xyz is read).
//   position: (out) generated point, w is set to 0.
//   velocity: (out) offset of the generated point from the box center.
void RandomInBox(float seed, float4 center, float4 size, out float4 position, out float3 velocity)
{
    // re-center the hash around zero (presumably hash31 returns values in the 0-1 range):
    float3 centered = hash31(seed) - float3(0.5,0.5,0.5);

    velocity = centered * size.xyz;
    position = float4(center.xyz + velocity,0);
}

// Gathers fluid particle data into the sorted buffers: one thread per sorted entry,
// each one copying the data of the particle referenced by sortedToOriginal.
[numthreads(128, 1, 1)]
void SortFluidData (uint3 id : SV_DispatchThreadID)
{
    uint sortedIndex = id.x;

    if (sortedIndex < dispatch[3])
    {
        int source = sortedToOriginal[sortedIndex];

        sortedPositions[sortedIndex] = positions[source];
        sortedVelocities[sortedIndex] = velocities[source];
        sortedOrientations[sortedIndex] = orientations[source];
        sortedRadii[sortedIndex] = principalRadii[source];

        // w is cleared: in the source buffer it holds the packed foam potential.
        sortedAngularVelocities[sortedIndex] = float4(angularVelocities[source].xyz,0);
    }
}

// Spawns up to particlesToEmit foam particles inside the emitter shape (cylinder or box),
// one particle per thread.
[numthreads(128, 1, 1)]
void EmitShape (uint3 id : SV_DispatchThreadID)
{
    uint i = id.x;
    if (i >= particlesToEmit) return;
    
    // atomically reserve a slot. The counter is incremented even when the buffer is full,
    // in which case the reserved slot lies beyond the buffer and nothing is written.
    uint slot;
    InterlockedAdd(dispatch[3], 1, slot);

    if (slot >= maxFoamParticles) return;

    // pick a random point in emitter-local space. The radial offset returned by the helpers is not used here.
    float4 localPosition;
    float3 radialOffset = float3(0,0,0);

    if (emitterShape == 0)
    {
        // cylinder along the local Y axis, centered at the origin:
        RandomInCylinder(randomSeed + i, -float4(0,1,0,0)*emitterSize.y*0.5, float4(0,1,0,0), emitterSize.y, max(emitterSize.x, emitterSize.z) * 0.5, localPosition, radialOffset);
    }
    else
    {
        RandomInBox(randomSeed + i, FLOAT4_ZERO, emitterSize, localPosition, radialOffset);
    }

    // randomize initial life and size:
    float2 random = hash21(randomSeed - i);
    float initialLife = max(0, lifetime - lifetime * random.x * lifetimeRandom);
    float initialSize = particleSize - particleSize * random.y * sizeRandom;

    // transform the point to the emitter's frame and initialize the particle at rest:
    outputPositions[slot] = float4(emitterPosition.xyz + rotate_vector(emitterRotation, localPosition.xyz),0);
    outputVelocities[slot] = float4(0,0,0, buoyancy);
    outputColors[slot] = foamColor;

    // attributes: remaining life, aging rate, size, packed per-particle parameters.
    outputAttributes[slot] = float4(1, 1/initialLife,initialSize,PackFloatRGBA(float4(airAging / 50.0, airDrag, drag, isosurface)));
}

// Updates the foam potential of every active fluid particle and spawns foam particles from it.
// One thread per active fluid particle. The potential is kept packed in angularVelocities.w:
//   x: fractional amount of particles pending emission, y: foam potential in the 0-1 range.
[numthreads(128, 1, 1)]
void Emit (uint3 id : SV_DispatchThreadID)
{
    uint i = id.x;
    if (i >= activeParticleCount) return;

    int p = activeParticles[i];

    float4 angVel = angularVelocities[p];
    float2 potential = UnpackFloatRG(angVel.w);

    // only the xyz components are treated as velocity (same as Update does), so that whatever is
    // stored in w cannot leak into the speed estimate or into the buoyancy slot of the emitted particles.
    float3 fluidVelocity = velocities[p].xyz;
    float fluidSpeed = length(fluidVelocity);

    // calculate fluid potential for foam generation:
    float vorticityPotential = Remap01(fluidData[p].z, vorticityRange.x, vorticityRange.y); 
    float velocityPotential = Remap01(fluidSpeed, velocityRange.x, velocityRange.y);
    float potentialDelta = velocityPotential * vorticityPotential * deltaTime * potentialIncrease;

    // update foam potential:
    potential.y = saturate(potential.y * potentialDiffusion + potentialDelta);

    // calculate amount of emitted particles: emit the integer part, keep the fraction for later.
    potential.x += foamGenerationRate * potential.y * deltaTime;
    int emitCount = (int)potential.x;
    potential.x -= emitCount;
    
    for (int j = 0; j < emitCount; ++j)
    {
        // atomically increment alive particle counter. It is incremented even when the buffer is full,
        // in which case the reserved slot lies beyond the buffer and nothing is written.
        uint count;
        InterlockedAdd(dispatch[3], 1, count);

        if (count < maxFoamParticles)
        {
            // initialize foam particle in a random position inside the cylinder swept by the fluid particle during this step:
            float3 radialVelocity;
            RandomInCylinder(randomSeed + p + j, positions[p], float4(normalizesafe(fluidVelocity), 0), fluidSpeed * deltaTime, principalRadii[p].x, outputPositions[count], radialVelocity);
            
            // calculate initial life/size/color:
            // NOTE(review): initialLife can be 0 (zero potential or fully randomized lifetime), which makes the aging rate below infinite.
            float2 random = hash21(randomSeed - p - j);
            float initialLife = max(0, potential.y * (lifetime - lifetime * random.x * lifetimeRandom));
            float initialSize = particleSize - particleSize * random.y * sizeRandom;

            // fluid velocity plus radial offset, buoyancy stored in w:
            outputVelocities[count] = float4(fluidVelocity + radialVelocity, buoyancy);
            outputColors[count] = foamColor;

            // attributes: remaining life, aging rate, size, packed per-particle parameters.
            outputAttributes[count] = float4(1, 1/initialLife,initialSize,PackFloatRGBA(float4(airAging / 50.0, airDrag, drag, isosurface)));
        }
    }

    // store the updated potential back, leaving the angular velocity untouched:
    angVel.w = PackFloatRG(potential);
    angularVelocities[p] = angVel;
}

// Prepares the counters for Update: derives the amount of 128-thread groups from the alive
// particle counter, takes a snapshot of it, and resets the survivor counters.
[numthreads(1, 1, 1)]
void CopyAliveCount (uint3 id : SV_DispatchThreadID)
{
    uint alive = dispatch[3];

    dispatch[0] = alive / 128 + 1; // thread groups
    dispatch[8] = alive;           // snapshot, used by Update as its thread limit.

    // survivor count and its thread group count, rebuilt by Update:
    dispatch[7] = 0;
    dispatch[4] = 0;
}

// Ages and accelerates foam particles, compacting the survivors into the output buffers.
// Every thread below the snapshot count (dispatch[8]) pops one particle from the input buffers by
// decrementing the alive counter (dispatch[3]). Particles that are still alive are appended to the output
// buffers at the slot reserved through dispatch[7]; dispatch[4] tracks the thread groups needed for them.
// Particles surrounded by enough fluid are advected by it and pushed by buoyancy (foam/bubbles),
// all others are only affected by gravity and air drag (spray).
[numthreads(128, 1, 1)]
void Update (uint3 id : SV_DispatchThreadID)
{
    uint i = id.x;
    if (i >= dispatch[8]) return; 
    
    // pop a particle: InterlockedAdd returns the value before the decrement, so subtract one to get its index.
    uint count;
    InterlockedAdd(dispatch[3], -1, count);
    count--;

    // skip slots beyond the buffer (the counter may exceed the capacity) and dead particles.
    if (count < maxFoamParticles && inputAttributes[count].x > 0)
    {
        // reserve an output slot and update the amount of thread groups needed for the survivors:
        uint aliveCount;
        InterlockedAdd(dispatch[7], 1, aliveCount);
        InterlockedMax(dispatch[4],(aliveCount + 1) / 128 + 1);

        // packedData: x = airAging / 50, y = airDrag, z = drag, w = isosurface.
        float4 attributes = inputAttributes[count];
        float4 packedData = UnpackFloatRGBA(attributes.w);

        // when mode == 1 only the 4 offsets with z = 0 are visited.
        int offsetCount = (mode == 1) ? 4 : 8;
        float4 advectedVelocity = FLOAT4_ZERO;
        float4 advectedAngVelocity = FLOAT4_ZERO;

        // start at -isosurface, so that a positive sum means the fluid density estimate exceeds the isosurface value.
        float kernelSum = -packedData.w;
        uint neighbourCount = 0;

        float4 diffusePos = inputPositions[count];

        // iterate over the grid levels listed in levelPopulation (element 0 is the amount of levels):
        for (uint m = 1; m <= levelPopulation[0]; ++m)
        {
            uint l = levelPopulation[m];
            float radius = CellSizeOfLevel(l);
            float interactionDist = radius * 0.5;

            float4 cellCoords = floor((diffusePos - solverBounds[0].min_) / radius);

            cellCoords[3] = 0;
            if (mode == 1)
                cellCoords[2] = 0;

            // position relative to the cell center: its sign tells which neighboring cells are within reach.
            float4 posInCell = diffusePos - (solverBounds[0].min_ + cellCoords * radius + float4(interactionDist,interactionDist,interactionDist,0));
            int4 quadrant = (int4)sign(posInCell);
            quadrant[3] = l;
           
            for (int j = 0; j < offsetCount; ++j)
            {
                int4 neighborCoord = (int4)cellCoords + quadrantOffsets[j] * quadrant;
                int cellIndex = gridHashToSortedIndex[GridHash(neighborCoord)];
                uint n = cellOffsets[cellIndex]; 
                uint end = n + cellCounts[cellIndex];

                for (;n < end; ++n)
                {
                    uint p = fluidSimplices[n];

                    // skip particles whose own cell coordinates differ from the cell being visited.
                    int4 particleCoord = int4(floor((positions[p].xyz - solverBounds[0].min_.xyz)/ radius).xyz,l);
                    if (any (particleCoord - neighborCoord))
                        continue;

                    float4 normal = diffusePos - positions[p];
                    normal[3] = 0;
                    if (mode == 1)
                        normal[2] = 0;

                    float d = length(normal);
                    if (d <= interactionDist)
                    {
                        float3 radii = fluidMaterial[p].x * (principalRadii[p].xyz / principalRadii[p].x);

                        // linear velocity induced by the fluid particle's rotation, weighted by a kernel normalized to 1 at zero distance:
                        float4 angVel = float4(cross(angularVelocities[p].xyz, normal.xyz),0);
                        advectedAngVelocity += angVel * Poly6(d, radii.x) / Poly6(0, radii.x);
                        
                        // measure the distance in the particle's local frame, scaled by its principal radii:
                        normal.xyz = rotate_vector(q_conj(orientations[p]), normal.xyz) / radii;
                        d = length(normal.xyz) * radii.x;

                        // scale by volume (* 1 / normalized density)
                        float w = Poly6(d, radii.x) / fluidData[p].x;

                        kernelSum += w;
                        advectedVelocity += float4(velocities[p].xyz,0) * w;
                        neighbourCount++;
                    }
                }
            }
        }
        
        float4 forces = FLOAT4_ZERO;
        float velocityScale = 1;

        // age faster as the buffer fills up: the multiplier goes from 1 towards agingOverPopulation.z.
        float agingScale = 1 + Remap01(dispatch[8] / (float)maxFoamParticles,agingOverPopulation.x,agingOverPopulation.y) * (agingOverPopulation.z - 1);

        // foam/bubble particle:
        if (kernelSum > EPSILON && neighbourCount >= minFluidNeighbors)
        {
            // advection: pull the velocity towards the weighted average fluid velocity (adding the isosurface value back to get the actual kernel sum).
            forces = packedData.z / deltaTime * (advectedVelocity / (kernelSum + packedData.w) + advectedAngVelocity - inputVelocities[count]);

            // buoyancy:
            forces -= float4(gravity,0) * inputVelocities[count].w * saturate(kernelSum); // TODO: larger particles should rise faster.
            
        }
        else // spray:
        { 
            // gravity:
            forces += float4(gravity,0);

            // atmospheric drag/aging:
            velocityScale = packedData.y;
            agingScale *= packedData.x * 50;
        }

        // don't change 4th component, as its used to store buoyancy control parameter.
        forces[3] = 0; 
        
        // update particle data:
        attributes.x -= attributes.y * deltaTime * agingScale; 
        //attributes.z += (attributes.y * deltaTime * agingScale) * 0.02; // increase size with age. TODO: maybe do in render shader?
        outputAttributes[aliveCount] = attributes;
        outputColors[aliveCount] = inputColors[count];

        // add forces to velocity. Drag is applied to the xyz components only: scaling the whole vector
        // would also scale the buoyancy control parameter stored in w every time the particle is spray.
        float4 newVelocity = inputVelocities[count] + forces * deltaTime;
        newVelocity.xyz *= velocityScale;

        outputPositions[aliveCount] = inputPositions[count];
        outputVelocities[aliveCount] = newVelocity;
    }
}

// Advances the position of every alive foam particle using its velocity. The w component of the position is left untouched.
[numthreads(128, 1, 1)]
void Integrate (uint3 id : SV_DispatchThreadID)
{
    uint particle = id.x;

    if (particle < dispatch[3])
    {
        float4 position = outputPositions[particle];
        position.xyz += outputVelocities[particle].xyz * deltaTime;
        outputPositions[particle] = position;
    }
}

// Copies the surviving particles from the input buffers to the output buffers. The first thread also
// publishes the survivor counters (dispatch[4], dispatch[7]) as the current group/particle counts (dispatch[0], dispatch[3]).
[numthreads(128, 1, 1)]
void Copy (uint3 id : SV_DispatchThreadID)
{
    uint particle = id.x;
    uint survivors = dispatch[7];

    if (particle == 0)
    {
        dispatch[0] = dispatch[4];
        dispatch[3] = survivors;
    }

    if (particle < survivors)
    {
        outputPositions[particle] = inputPositions[particle];
        outputVelocities[particle] = inputVelocities[particle];
        outputColors[particle] = inputColors[particle];
        outputAttributes[particle] = inputAttributes[particle];
    }
}

// One compare-and-swap pass over the foam particles: each thread compares a pair of particles by their
// projection onto sortAxis and swaps all of their data if the one at the lower index has the smaller projection.
// NOTE(review): the pairing scheme (groupWidth / groupHeight / stepIndex) looks like a sorting network driven
// from the CPU - confirm against the host code.
[numthreads(128,1,1)]
void Sort(uint3 id : SV_DispatchThreadID) 
{
    uint thread = id.x;

    // pick the pair of entries this thread is responsible for:
    uint column = thread & (groupWidth - 1);
    uint lower = column + (groupHeight + 1) * (thread / groupWidth);

    uint stride;
    if (stepIndex == 0)
        stride = groupHeight - 2 * column;
    else
        stride = (groupHeight + 1) / 2;

    uint upper = lower + stride;

    // pairs reaching past the last alive particle are skipped:
    if (upper >= dispatch[3]) return;

    float4 lowerPos = outputPositions[lower];
    float4 upperPos = outputPositions[upper];

    // projection of both positions onto the sort axis:
    float lowerKey = dot(lowerPos.xyz, sortAxis.xyz);
    float upperKey = dot(upperPos.xyz, sortAxis.xyz);

    if (lowerKey < upperKey)
    {
        // out of order: exchange every per-particle buffer entry.
        outputPositions[lower] = upperPos;
        outputPositions[upper] = lowerPos;

        float4 lowerVel = outputVelocities[lower];
        outputVelocities[lower] = outputVelocities[upper];
        outputVelocities[upper] = lowerVel;

        float4 lowerColor = outputColors[lower];
        outputColors[lower] = outputColors[upper];
        outputColors[upper] = lowerColor;

        float4 lowerAttr = outputAttributes[lower];
        outputAttributes[lower] = outputAttributes[upper];
        outputAttributes[upper] = lowerAttr;
    }
}

// Zeroes the 6 indices (2 triangles) that belong to each particle slot, for the entire buffer capacity.
[numthreads(128, 1, 1)]
void ClearMesh (uint3 id : SV_DispatchThreadID) 
{
    uint particle = id.x;
    if (particle >= maxFoamParticles) return;

    uint firstIndex = particle * 6;

    // <<2 turns an index position into a byte address (4 bytes per index).
    for (uint k = 0; k < 6; ++k)
        indices.Store((firstIndex + k) << 2, 0);
}

// Writes one quad (4 vertices, 2 triangles) per alive foam particle.
// Each vertex takes 19 32-bit words:
//   0-3:   particle position (w = 1)
//   4-6:   corner offset of the vertex within the quad
//   7-10:  color
//   11-14: velocity (xyz) and the w component of the particle position
//   15-18: attributes
[numthreads(128, 1, 1)]
void BuildMesh (uint3 id : SV_DispatchThreadID) 
{
    uint particle = id.x;
    if (particle >= dispatch[3]) return;

    // particle data is the same for all 4 vertices, so fetch and convert it once:
    float4 position = inputPositions[particle];
    uint4 positionBits = asuint(float4(position.xyz, 1));
    uint4 colorBits = asuint(inputColors[particle]);
    uint4 velocityBits = asuint(float4(inputVelocities[particle].xyz, position.w));
    uint4 attributeBits = asuint(inputAttributes[particle]);

    // different offset for each vertex:
    float3 corners[4] =
    {
        float3(1,1,0),
        float3(-1,1,0),
        float3(-1,-1,0),
        float3(1,-1,0)
    };

    uint firstVertex = particle * 4;

    for (uint c = 0; c < 4; ++c)
    {
        // first word of this vertex. <<2 = multiply by 4 to get byte address, since a float/int is 4 bytes in size.
        uint word = (firstVertex + c) * 19;

        vertices.Store4(word << 2, positionBits);
        vertices.Store3((word + 4) << 2, asuint(corners[c]));
        vertices.Store4((word + 7) << 2, colorBits);
        vertices.Store4((word + 11) << 2, velocityBits);
        vertices.Store4((word + 15) << 2, attributeBits);
    }

    // indices: two triangles, (2,1,0) and (3,2,0), relative to the first vertex of the quad.
    uint firstIndex = particle * 6;

    indices.Store((firstIndex) << 2, firstVertex + 2);
    indices.Store((firstIndex + 1) << 2, firstVertex + 1);
    indices.Store((firstIndex + 2) << 2, firstVertex);

    indices.Store((firstIndex + 3) << 2, firstVertex + 3);
    indices.Store((firstIndex + 4) << 2, firstVertex + 2);
    indices.Store((firstIndex + 5) << 2, firstVertex);
}
